//===- Utils.cpp ---- Misc utilities for loop transformation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop transformation routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#include <cstdint>

using namespace mlir;

#define DEBUG_TYPE "scf-utils"

SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
    bool replaceIterOperandsUsesInLoop) {
  // An empty nest has nothing to rewrite.
  if (loopNest.empty())
    return {};
  // One recursive call is made per nest level, so cap the depth that is
  // accepted. (See
  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
  assert(loopNest.size() <= 10 &&
         "exceeded recursion limit when yielding value from loop nest");

  // Threading a value out of a perfectly nested loop nest turns
  //
  // ```mlir
  //  scf.for .. {
  //    scf.for .. {
  //      scf.for .. {
  //        %value = ...
  //      }
  //    }
  //  }
  // ```
  //
  // into
  //
  // ```mlir
  // %0 = scf.for .. iter_args(%arg0 = %init) {
  //   %1 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %2 = scf.for .. iter_args(%arg2 = %arg1) {
  //       %value = ...
  //       scf.yield %value
  //     }
  //     scf.yield %2
  //   }
  //   scf.yield %1
  // }
  // ```
  //
  // Base case: a single loop is rewritten directly by
  // `replaceWithAdditionalYields`, using the caller-provided yield callback.
  if (loopNest.size() == 1) {
    auto replacedInnermost = loopNest.back().replaceWithAdditionalYields(
        rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
        newYieldValuesFn);
    SmallVector<scf::ForOp> singleLoopNest;
    singleLoopNest.push_back(cast<scf::ForOp>(*replacedInnermost));
    return singleLoopNest;
  }

  // Recursive case: while the outermost loop is being rebuilt, its yield
  // callback rebuilds the remaining (inner) loops.
  // - The new region iter args of this loop become the inits of the inner
  //   loop.
  // - The trailing results of the rebuilt inner loop are what this loop
  //   yields.
  SmallVector<scf::ForOp> rebuiltNest;
  NewYieldValuesFn yieldFromInnerNest =
      [&](OpBuilder &innerBuilder, Location loc,
          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
    rebuiltNest = replaceLoopNestWithNewYields(
        rewriter, loopNest.drop_front(), innerNewBBArgs, newYieldValuesFn,
        replaceIterOperandsUsesInLoop);
    SmallVector<Value> forwardedResults;
    for (OpResult innerResult :
         rebuiltNest.front().getResults().take_back(innerNewBBArgs.size()))
      forwardedResults.push_back(innerResult);
    return forwardedResults;
  };
  auto replacedOutermost = loopNest.front().replaceWithAdditionalYields(
      rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
      yieldFromInnerNest);
  // The callback filled `rebuiltNest` with the inner loops; prepend the new
  // outermost loop so the result is ordered outermost-to-innermost.
  rebuiltNest.insert(rebuiltNest.begin(),
                     cast<scf::ForOp>(*replacedOutermost));
  return rebuiltNest;
}

/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
/// This method also clones the `arith::ConstantIndexOp` at the start of
/// `outlinedFuncBody` to allow simple canonicalizations. If `callOp` is
/// provided, it will be set to point to the operation that calls the outlined
/// function.
///
/// The new function is created right before the `FunctionOpInterface` op that
/// encloses `region`. Its arguments are the region's block arguments followed
/// by the values used in `region` but defined above it. After outlining,
/// `region` holds a single new block containing the call and a clone of the
/// original terminator that forwards the call results.
///
/// Returns failure if `region` does not have exactly one block; `funcName`
/// must be non-empty.
// TODO: support more than single-block regions.
// TODO: more flexible constant handling.
FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
                                                       Location loc,
                                                       Region &region,
                                                       StringRef funcName,
                                                       func::CallOp *callOp) {
  assert(!funcName.empty() && "funcName cannot be empty");
  if (!region.hasOneBlock())
    return failure();

  Block *originalBlock = &region.front();
  Operation *originalTerminator = originalBlock->getTerminator();

  // Outline before current function.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());

  // Values used inside `region` but defined outside of it; they become
  // trailing arguments of the outlined function.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(region, captures);

  ValueRange outlinedValues(captures.getArrayRef());
  SmallVector<Type> outlinedFuncArgTypes;
  SmallVector<Location> outlinedFuncArgLocs;
  // Region's arguments are exactly the first block's arguments as per
  // Region::getArguments().
  // Func's arguments are cat(regions's arguments, captures arguments).
  for (BlockArgument arg : region.getArguments()) {
    outlinedFuncArgTypes.push_back(arg.getType());
    outlinedFuncArgLocs.push_back(arg.getLoc());
  }
  for (Value value : outlinedValues) {
    outlinedFuncArgTypes.push_back(value.getType());
    outlinedFuncArgLocs.push_back(value.getLoc());
  }
  // The function returns exactly what the original terminator was given.
  FunctionType outlinedFuncType =
      FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes,
                        originalTerminator->getOperandTypes());
  auto outlinedFunc =
      func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
  Block *outlinedFuncBody = outlinedFunc.addEntryBlock();

  // Merge blocks while replacing the original block operands.
  // Warning: `mergeBlocks` erases the original block, reconstruct it later.
  int64_t numOriginalBlockArguments = originalBlock->getNumArguments();
  auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    rewriter.mergeBlocks(
        originalBlock, outlinedFuncBody,
        outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
    // Explicitly set up a new ReturnOp terminator.
    rewriter.setInsertionPointToEnd(outlinedFuncBody);
    func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(),
                           originalTerminator->getOperands());
  }

  // Reconstruct the block that was deleted and add a
  // terminator(call_results).
  // The leading entries of the arg type/loc vectors are those of the original
  // block arguments, so the new block gets the same signature.
  Block *newBlock = rewriter.createBlock(
      &region, region.begin(),
      TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
      ArrayRef<Location>(outlinedFuncArgLocs)
          .take_front(numOriginalBlockArguments));
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    // Call operands mirror the function signature: block arguments first,
    // captured values second.
    SmallVector<Value> callValues;
    llvm::append_range(callValues, newBlock->getArguments());
    llvm::append_range(callValues, outlinedValues);
    auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
    if (callOp)
      *callOp = call;

    // `originalTerminator` was moved to `outlinedFuncBody` and is still valid.
    // Clone `originalTerminator` to take the callOp results then erase it from
    // `outlinedFuncBody`.
    IRMapping bvm;
    bvm.map(originalTerminator->getOperands(), call->getResults());
    rewriter.clone(*originalTerminator, bvm);
    rewriter.eraseOp(originalTerminator);
  }

  // Lastly, explicit RAUW outlinedValues, only for uses within `outlinedFunc`.
  // Clone the `arith::ConstantIndexOp` at the start of `outlinedFuncBody`.
  for (auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
                                               outlinedValues.size()))) {
    Value orig = std::get<0>(it);
    Value repl = std::get<1>(it);
    {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToStart(outlinedFuncBody);
      // A captured index constant is re-materialized inside the function and
      // used instead of the corresponding function argument.
      if (Operation *cst = orig.getDefiningOp<arith::ConstantIndexOp>()) {
        repl = rewriter.clone(*cst)->getResult(0);
      }
    }
    // Only rewrite uses nested in the outlined function; uses elsewhere
    // (including the new call's operands) keep the original value.
    orig.replaceUsesWithIf(repl, [&](OpOperand &opOperand) {
      return outlinedFunc->isProperAncestor(opOperand.getOwner());
    });
  }

  return outlinedFunc;
}

/// Outlines the then and/or else region of `ifOp` into standalone functions.
/// A branch is outlined only when its output pointer (`thenFn` / `elseFn`) is
/// non-null and its region is non-empty; the created function is stored
/// through that pointer. Returns failure as soon as outlining a branch fails.
LogicalResult mlir::outlineIfOp(RewriterBase &b, scf::IfOp ifOp,
                                func::FuncOp *thenFn, StringRef thenFnName,
                                func::FuncOp *elseFn, StringRef elseFnName) {
  IRRewriter rewriter(b);
  Location loc = ifOp.getLoc();

  // Outlines `branch` into a function called `name` when requested, writing
  // the new function to `outFn`. Skipped branches count as success.
  auto outlineBranch = [&](func::FuncOp *outFn, Region &branch,
                           StringRef name) -> LogicalResult {
    if (!outFn || branch.empty())
      return success();
    FailureOr<func::FuncOp> outlined =
        outlineSingleBlockRegion(rewriter, loc, branch, name);
    if (failed(outlined))
      return failure();
    *outFn = *outlined;
    return success();
  };

  if (failed(outlineBranch(thenFn, ifOp.getThenRegion(), thenFnName)))
    return failure();
  return outlineBranch(elseFn, ifOp.getElseRegion(), elseFnName);
}

/// Recursively visits every operation nested under `rootOp` and appends to
/// `result` each `scf.parallel` that does not itself enclose another
/// `scf.parallel`. Returns true if at least one `scf.parallel` is nested
/// under `rootOp` (`rootOp` itself is not inspected).
bool mlir::getInnermostParallelLoops(Operation *rootOp,
                                     SmallVectorImpl<scf::ParallelOp> &result) {
  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
  bool foundParallelLoop = false;
  for (Region &region : rootOp->getRegions()) {
    for (Block &block : region.getBlocks()) {
      for (Operation &nestedOp : block) {
        // Visit the children first so that we know whether `nestedOp` wraps
        // any parallel loop before deciding if it is an innermost one.
        const bool wrapsParallelLoop =
            getInnermostParallelLoops(&nestedOp, result);
        auto parallelOp = dyn_cast<scf::ParallelOp>(nestedOp);
        if (!parallelOp) {
          foundParallelLoop |= wrapsParallelLoop;
          continue;
        }
        foundParallelLoop = true;
        // Only parallel loops with no parallel loop inside are collected.
        if (!wrapsParallelLoop)
          result.push_back(parallelOp);
      }
    }
  }
  return foundParallelLoop;
}

// Emits IR computing the ceiling division of a positive value by a positive
// compile-time constant:
//    ceildiv(a, B) = (a + (B - 1)) / B
// The division is emitted as `arith.divui`. Both constants are created with
// the type of `dividend`, which must be an integer or index type.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             int64_t divisor) {
  assert(divisor > 0 && "expected positive divisor");
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");

  Type valueType = dividend.getType();
  // Materializes `cst` as a constant of the dividend's type.
  auto makeConstant = [&](int64_t cst) -> Value {
    return arith::ConstantOp::create(builder, loc,
                                     builder.getIntegerAttr(valueType, cst));
  };

  Value roundingBias = makeConstant(divisor - 1);
  Value denominator = makeConstant(divisor);
  Value biasedDividend =
      arith::AddIOp::create(builder, loc, dividend, roundingBias);
  return arith::DivUIOp::create(builder, loc, biasedDividend, denominator);
}

// Emits IR computing the ceiling division of a positive value by another
// positive value that is only known at runtime:
//    ceildiv(a, b) = (a + (b - 1)) / b
// The division is emitted as `arith.divui`. `dividend` must be an integer or
// index-typed value; the constant one is created with its type.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                             Value divisor) {
  assert(dividend.getType().isIntOrIndex() &&
         "expected integer or index-typed value");
  Type valueType = dividend.getType();
  Value one =
      arith::ConstantOp::create(builder, loc, builder.getOneAttr(valueType));
  Value roundingBias = arith::SubIOp::create(builder, loc, divisor, one);
  Value biasedDividend =
      arith::AddIOp::create(builder, loc, dividend, roundingBias);
  return arith::DivUIOp::create(builder, loc, biasedDividend, divisor);
}

/// Unrolls the body held in 'loopBodyBlock' in place: `unrollFactor - 1`
/// clones of its non-terminator operations are appended before the
/// terminator. For each clone, 'ivRemapFn' provides the value standing in for
/// 'forOpIV', and `iterArgs` are remapped to the values yielded by the
/// previous copy. The terminator is finally rewired to yield the values of
/// the last copy. When given, 'annotateFn' is invoked on every operation of
/// every copy (copy 0 being the original body).
static void generateUnrolledLoop(
    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
    ValueRange iterArgs, ValueRange yieldedValues) {
  // Every clone is emitted right in front of the block terminator.
  OpBuilder builder = OpBuilder::atBlockTerminator(loopBodyBlock);

  // Substitute a no-op when the caller did not provide an annotation hook.
  constexpr auto noOpAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
  if (!annotateFn)
    annotateFn = noOpAnnotateFn;

  // The block grows while we clone into it, so pin down the last original
  // non-terminator operation to delimit what has to be copied.
  Block::iterator lastOriginalOp = std::prev(loopBodyBlock->end(), 2);

  // Values carried from one unrolled copy into the next; initially the values
  // yielded by the original body.
  SmallVector<Value, 4> carriedValues(yieldedValues);

  for (unsigned copyIdx = 1; copyIdx < unrollFactor; ++copyIdx) {
    IRMapping mapping;

    // The iter args of this copy are what the previous copy yielded.
    mapping.map(iterArgs, carriedValues);

    // Only materialize a remapped induction variable when it is used.
    if (!forOpIV.use_empty()) {
      Value remappedIV = ivRemapFn(copyIdx, forOpIV, builder);
      mapping.map(forOpIV, remappedIV);
    }

    // Clone the original operations. The end of the range is recomputed on
    // every step because cloning inserts operations after `lastOriginalOp`.
    for (Block::iterator opIt = loopBodyBlock->begin();
         opIt != std::next(lastOriginalOp); ++opIt) {
      Operation *copiedOp = builder.clone(*opIt, mapping);
      annotateFn(copyIdx, copiedOp, builder);
    }

    // Record what this copy yields for the benefit of the next one.
    for (unsigned valueIdx = 0, numCarried = carriedValues.size();
         valueIdx < numCarried; ++valueIdx)
      carriedValues[valueIdx] =
          mapping.lookupOrDefault(yieldedValues[valueIdx]);
  }

  // Annotate the original operations last, so that their annotations are not
  // duplicated into the clones created above.
  for (Block::iterator opIt = loopBodyBlock->begin();
       opIt != std::next(lastOriginalOp); ++opIt)
    annotateFn(0, &*opIt, builder);

  // The terminator now yields the values produced by the final copy.
  loopBodyBlock->getTerminator()->setOperands(carriedValues);
}

/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
/// epilogue loop, if the loop is unrolled.
///
/// - A loop whose body only holds the terminator is returned untouched.
/// - With a static trip count, the unrolled upper bound and step are emitted
///   as constants and an epilogue loop is only created when the unrolled upper
///   bound is smaller than the original one.
/// - Otherwise the bounds are computed with `arith` ops inserted before
///   'forOp' and an epilogue loop is always created.
/// - `mainLoopOp` / `epilogueLoopOp` of the result are only set for loops
///   that `promoteIfSingleIteration` did not promote.
///
/// If given, 'annotateFn' is forwarded to `generateUnrolledLoop` to annotate
/// the operations of each unrolled copy. Returns failure only when
/// 'unrollFactor' is 1, the static trip count is 1 and promoting the single
/// iteration fails.
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
    scf::ForOp forOp, uint64_t unrollFactor,
    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
  assert(unrollFactor > 0 && "expected positive unroll factor");

  // Return if the loop body is empty.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return UnrolledLoopInfo{forOp, std::nullopt};

  // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
  // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
  // New bound computations are inserted right before 'forOp'.
  OpBuilder boundsBuilder(forOp);
  IRRewriter rewriter(forOp.getContext());
  auto loc = forOp.getLoc();
  Value step = forOp.getStep();
  Value upperBoundUnrolled;
  Value stepUnrolled;
  bool generateEpilogueLoop = true;

  std::optional<APInt> constTripCount = forOp.getStaticTripCount();
  if (constTripCount) {
    // Constant loop bounds computation.
    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
    // Unrolling by one is a no-op, except that a single-iteration loop gets
    // promoted into its parent.
    if (unrollFactor == 1) {
      if (*constTripCount == 1 &&
          failed(forOp.promoteIfSingleIteration(rewriter)))
        return failure();
      return UnrolledLoopInfo{forOp, std::nullopt};
    }

    // Largest multiple of 'unrollFactor' not exceeding the trip count; the
    // main loop covers exactly that many original iterations.
    int64_t tripCountEvenMultiple =
        constTripCount->getSExtValue() -
        (constTripCount->getSExtValue() % unrollFactor);
    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
    int64_t stepUnrolledCst = stepCst * unrollFactor;

    // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
    generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
    if (generateEpilogueLoop)
      upperBoundUnrolled = arith::ConstantOp::create(
          boundsBuilder, loc,
          boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
                                       upperBoundUnrolledCst));
    else
      upperBoundUnrolled = forOp.getUpperBound();

    // Create constant for 'stepUnrolled'.
    // The existing step value is reused when scaling leaves it unchanged.
    stepUnrolled =
        stepCst == stepUnrolledCst
            ? step
            : arith::ConstantOp::create(boundsBuilder, loc,
                                        boundsBuilder.getIntegerAttr(
                                            step.getType(), stepUnrolledCst));
  } else {
    // Dynamic loop bounds computation.
    // TODO: Add dynamic asserts for negative lb/ub/step, or
    // consider using ceilDiv from AffineApplyExpander.
    auto lowerBound = forOp.getLowerBound();
    auto upperBound = forOp.getUpperBound();
    Value diff =
        arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
    Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
    Value unrollFactorCst = arith::ConstantOp::create(
        boundsBuilder, loc,
        boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
    Value tripCountRem =
        arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
    // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
    Value tripCountEvenMultiple =
        arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
    // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step
    upperBoundUnrolled = arith::AddIOp::create(
        boundsBuilder, loc, lowerBound,
        arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
    // Scale 'step' by 'unrollFactor'.
    stepUnrolled =
        arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
  }

  UnrolledLoopInfo resultLoops;

  // Create epilogue clean up loop starting at 'upperBoundUnrolled'.
  if (generateEpilogueLoop) {
    OpBuilder epilogueBuilder(forOp->getContext());
    epilogueBuilder.setInsertionPointAfter(forOp);
    // The epilogue is a clone of the not-yet-unrolled loop placed right after
    // it, restricted to the iterations the main loop no longer covers.
    auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.clone(*forOp));
    epilogueForOp.setLowerBound(upperBoundUnrolled);

    // Update uses of loop results.
    auto results = forOp.getResults();
    auto epilogueResults = epilogueForOp.getResults();

    for (auto e : llvm::zip(results, epilogueResults)) {
      std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
    }
    // Chain the loops: the epilogue starts from the main loop's results. This
    // is done after the RAUW above so these new uses are not redirected.
    epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
                               epilogueForOp.getInitArgs().size(), results);
    // Only report the epilogue loop if it was not promoted.
    if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
      resultLoops.epilogueLoopOp = epilogueForOp;
  }

  // Create unrolled loop.
  forOp.setUpperBound(upperBoundUnrolled);
  forOp.setStep(stepUnrolled);

  auto iterArgs = ValueRange(forOp.getRegionIterArgs());
  auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();

  generateUnrolledLoop(
      forOp.getBody(), forOp.getInductionVar(), unrollFactor,
      [&](unsigned i, Value iv, OpBuilder b) {
        // iv' = iv + step * i;
        // `step` is the original (unscaled) step captured above.
        auto stride = arith::MulIOp::create(
            b, loc, step,
            arith::ConstantOp::create(b, loc,
                                      b.getIntegerAttr(iv.getType(), i)));
        return arith::AddIOp::create(b, loc, iv, stride);
      },
      annotateFn, iterArgs, yieldedValues);
  // Promote the loop body up if this has turned into a single iteration loop.
  // Only report the main loop if it was not promoted.
  if (forOp.promoteIfSingleIteration(rewriter).failed())
    resultLoops.mainLoopOp = forOp;
  return resultLoops;
}

/// Fully unrolls 'forOp'. This requires a statically known trip count and
/// fails otherwise. A loop with zero iterations is left as is, a loop with a
/// single iteration is promoted into its parent, and any other loop is
/// unrolled by a factor equal to its trip count.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
  IRRewriter rewriter(forOp.getContext());
  std::optional<APInt> staticTripCount = forOp.getStaticTripCount();
  if (!staticTripCount)
    return failure();
  // Nothing to unroll when the loop never runs.
  if (staticTripCount->isZero())
    return success();
  const int64_t numIterations = staticTripCount->getSExtValue();
  if (numIterations == 1)
    return forOp.promoteIfSingleIteration(rewriter);
  return loopUnrollByFactor(forOp, numIterations);
}

/// Returns true when the lower bound, upper bound and step of every
/// `scf.for` visited by walking `forOp` are defined outside of `forOp`, and
/// false as soon as one of them is not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
  auto isLoopInvariant = [&](Value controlOperand) {
    return forOp.isDefinedOutsideOfLoop(controlOperand);
  };
  WalkResult status = forOp.walk([&](scf::ForOp nestedLoop) -> WalkResult {
    const bool hasInvariantBounds =
        isLoopInvariant(nestedLoop.getLowerBound()) &&
        isLoopInvariant(nestedLoop.getUpperBound()) &&
        isLoopInvariant(nestedLoop.getStep());
    // Stop at the first loop with a bound computed inside `forOp`.
    return hasInvariantBounds ? WalkResult::advance()
                              : WalkResult::interrupt();
  });
  return !status.wasInterrupted();
}

/// Unrolls and jams this loop by the specified factor.
///
/// The transformation is only applied when:
/// - the bounds and steps of all inner loops are defined outside of `forOp`,
/// - `forOp` has no results,
/// - `forOp` has a static trip count and a constant step, and
/// - the trip count is a multiple of `unrollJamFactor`. A factor larger than
///   the trip count is clamped to the trip count instead.
/// Failure is returned (before any IR is modified) when one of these
/// conditions does not hold. A factor of 1, a loop that never executes, and a
/// loop whose body only holds the terminator are left untouched and reported
/// as success.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
                                          uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  if (unrollJamFactor == 1)
    return success();

  // If any control operand of any inner loop of `forOp` is defined within
  // `forOp`, no unroll jam.
  if (!areInnerBoundsInvariant(forOp)) {
    LDBG() << "failed to unroll and jam: inner bounds are not invariant";
    return failure();
  }

  // Currently, for operations with results are not supported.
  if (forOp->getNumResults() > 0) {
    LDBG() << "failed to unroll and jam: unsupported loop with results";
    return failure();
  }

  // Currently, only constant trip count that divided by the unroll factor is
  // supported.
  std::optional<APInt> tripCount = forOp.getStaticTripCount();
  if (!tripCount.has_value()) {
    // If the trip count is dynamic, do not unroll & jam.
    LDBG() << "failed to unroll and jam: trip count could not be determined";
    return failure();
  }
  // A loop that never executes has nothing to jam. Bail out here: clamping
  // the factor to a zero trip count below would otherwise make every
  // `unrollJamFactor - 1` computation wrap around and every division by the
  // factor divide by zero.
  if (tripCount->isZero()) {
    LDBG() << "nothing to unroll and jam: loop has zero iterations";
    return success();
  }
  if (unrollJamFactor > tripCount->getZExtValue()) {
    LDBG() << "unroll and jam factor is greater than trip count, set factor to "
              "trip "
              "count";
    unrollJamFactor = tripCount->getZExtValue();
  } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
    LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
              "multiple of unroll jam factor";
    return failure();
  }

  // Nothing in the loop body other than the terminator.
  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
    return success();

  // The remapped induction variables below are built from the constant step.
  // Check that it is available before the IR is modified instead of
  // dereferencing an empty optional half-way through the rewrite.
  auto constantStep = forOp.getConstantStep();
  if (!constantStep) {
    LDBG() << "failed to unroll and jam: step is not a constant";
    return failure();
  }

  // Gather all sub-blocks to jam upon the loop being unrolled.
  JamBlockGatherer<scf::ForOp> jbg;
  jbg.walk(forOp);
  auto &subBlocks = jbg.subBlocks;

  // Collect inner loops.
  SmallVector<scf::ForOp> innerLoops;
  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });

  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
  // iteration. There are (`unrollJamFactor` - 1) iterations.
  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);

  // For any loop with iter_args, replace it with a new loop that has
  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
  // operands.
  SmallVector<scf::ForOp> newInnerLoops;
  IRRewriter rewriter(forOp.getContext());
  for (scf::ForOp oldForOp : innerLoops) {
    SmallVector<Value> dupIterOperands, dupYieldOperands;
    ValueRange oldIterOperands = oldForOp.getInits();
    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
    ValueRange oldYieldOperands =
        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
    // Get additional iterOperands, iterArgs, and yield operands. We will
    // fix iterOperands and yield operands after cloning of sub-blocks.
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
    }
    // Create a new loop with additional iterOperands, iter_args and yield
    // operands. This new loop will take the loop body of the original loop.
    bool forOpReplaced = oldForOp == forOp;
    scf::ForOp newForOp =
        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
              return dupYieldOperands;
            }));
    newInnerLoops.push_back(newForOp);
    // `forOp` has been replaced with a new loop.
    if (forOpReplaced)
      forOp = newForOp;
    // Update `operandMaps` for `newForOp` iterArgs and results.
    ValueRange newIterArgs = newForOp.getRegionIterArgs();
    unsigned oldNumIterArgs = oldIterArgs.size();
    ValueRange newResults = newForOp.getResults();
    unsigned oldNumResults = newResults.size() / unrollJamFactor;
    assert(oldNumIterArgs == oldNumResults &&
           "oldNumIterArgs must be the same as oldNumResults");
    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
        // to those in the `i`th new set.
        operandMaps[i - 1].map(newIterArgs[j],
                               newIterArgs[i * oldNumIterArgs + j]);
        operandMaps[i - 1].map(newResults[j],
                               newResults[i * oldNumResults + j]);
      }
    }
  }

  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
  // The scaling constant is created with the type of the step so that loops
  // over integer types other than `index` get well-typed arithmetic.
  rewriter.setInsertionPoint(forOp);
  int64_t step = constantStep->getSExtValue();
  auto newStep = rewriter.createOrFold<arith::MulIOp>(
      forOp.getLoc(), forOp.getStep(),
      rewriter.createOrFold<arith::ConstantOp>(
          forOp.getLoc(), rewriter.getIntegerAttr(forOp.getStep().getType(),
                                                  unrollJamFactor)));
  forOp.setStep(newStep);
  auto forOpIV = forOp.getInductionVar();

  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
    for (auto &subBlock : subBlocks) {
      // Builder to insert unroll-jammed bodies. Insert right at the end of
      // sub-block.
      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));

      // If the induction variable is used, create a remapping to the value for
      // this unrolled instance.
      if (!forOpIV.use_empty()) {
        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
        // The offset constant takes the type of the induction variable.
        auto ivTag = builder.createOrFold<arith::ConstantOp>(
            forOp.getLoc(),
            builder.getIntegerAttr(forOpIV.getType(), step * i));
        auto ivUnroll =
            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
        operandMaps[i - 1].map(forOpIV, ivUnroll);
      }
      // Clone the sub-block being unroll-jammed.
      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
        builder.clone(*it, operandMaps[i - 1]);
    }
    // Fix iterOperands and yield op operands of newly created loops.
    for (auto newForOp : newInnerLoops) {
      unsigned oldNumIterOperands =
          newForOp.getNumRegionIterArgs() / unrollJamFactor;
      unsigned numControlOperands = newForOp.getNumControlOperands();
      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
      assert(oldNumIterOperands == oldNumYieldOperands &&
             "oldNumIterOperands must be the same as oldNumYieldOperands");
      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
        // The `i`th duplication of an old iterOperand or yield op operand
        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
        // if such mapped value exists.
        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
                            operandMaps[i - 1].lookupOrDefault(
                                newForOp.getOperand(numControlOperands + j)));
        yieldOp.setOperand(
            i * oldNumYieldOperands + j,
            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
      }
    }
  }

  // Promote the loop body up if this has turned into a single iteration loop.
  (void)forOp.promoteIfSingleIteration(rewriter);
  return success();
}

/// Computes the normalized bounds of an index-typed loop: offset 0, stride 1
/// and a size of `ceildiv(ub - lb, step)`, built as a composed and folded
/// affine apply over `lb`, `ub` and `step`.
static Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter,
                                                  Location loc, OpFoldResult lb,
                                                  OpFoldResult ub,
                                                  OpFoldResult step) {
  // Symbols are bound in the order of the operands passed below.
  AffineExpr lbSym, ubSym, stepSym;
  bindSymbols(rewriter.getContext(), lbSym, ubSym, stepSym);
  AffineExpr tripCountExpr = (ubSym - lbSym).ceilDiv(stepSym);

  Range normalized;
  normalized.size = affine::makeComposedFoldedAffineApply(
      rewriter, loc, tripCountExpr, {lb, ub, step});
  normalized.offset = rewriter.getIndexAttr(0);
  normalized.stride = rewriter.getIndexAttr(1);
  return normalized;
}

/// Returns the bounds of the loop `[lb, ub)` with step `step` once normalized
/// to start at zero with a unit step. Index-typed bounds are delegated to the
/// affine-based helper; other integer types are handled with `arith` ops.
/// The step is assumed to be strictly positive.
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                     OpFoldResult lb, OpFoldResult ub,
                                     OpFoldResult step) {
  Type rangeType = getType(lb);
  if (rangeType.isIndex())
    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);

  // Non-index types: emit `arith` operations. First detect what is already
  // normalized, i.e. a constant zero lower bound and/or a constant unit step.
  auto lbCst = getConstantIntValue(lb);
  const bool isZeroBased = lbCst && *lbCst == 0;
  auto stepCst = getConstantIntValue(step);
  const bool isStepOne = stepCst && *stepCst == 1;

  assert(rangeType == getType(ub) && rangeType == getType(step) &&
         "expected matching types");

  // An already normalized loop is returned unchanged.
  if (isZeroBased && isStepOne)
    return {lb, ub, step};

  // Number of iterations: ceildiv(ub - lb, step). The subtraction and the
  // division are skipped when they would be no-ops.
  OpFoldResult iterationCount = ub;
  if (!isZeroBased) {
    iterationCount = rewriter.createOrFold<arith::SubIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
        getValueOrCreateConstantIntOp(rewriter, loc, lb));
  }
  if (!isStepOne) {
    iterationCount = rewriter.createOrFold<arith::CeilDivSIOp>(
        loc, getValueOrCreateConstantIntOp(rewriter, loc, iterationCount),
        getValueOrCreateConstantIntOp(rewriter, loc, step));
  }

  // The normalized loop runs from 0 to the iteration count with step 1.
  OpFoldResult zero = rewriter.getZeroAttr(rangeType);
  OpFoldResult one = rewriter.getOneAttr(rangeType);
  return {zero, iterationCount, one};
}

/// Index-typed variant of induction variable denormalization: computes
/// `normalizedIv * origStep + origLb` through a composed and folded affine
/// apply and replaces the uses of `normalizedIv` with it.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
                                                     Location loc,
                                                     Value normalizedIv,
                                                     OpFoldResult origLb,
                                                     OpFoldResult origStep) {
  MLIRContext *ctx = rewriter.getContext();
  AffineExpr ivDim, lbSym, stepSym;
  bindDims(ctx, ivDim);
  bindSymbols(ctx, lbSym, stepSym);
  AffineExpr denormalizeExpr = ivDim * stepSym + lbSym;

  OpFoldResult foldedIv = affine::makeComposedFoldedAffineApply(
      rewriter, loc, denormalizeExpr,
      ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
  Value originalIv = getValueOrCreateConstantIndexOp(rewriter, loc, foldedIv);

  // When an `affine.apply` is generated for the denormalization, its own use
  // of the normalized induction variable must be left alone. No such
  // operation is generated when `origLb == 0` and `origStep == 1`.
  SmallPtrSet<Operation *, 1> usersToKeep;
  const bool isAlreadyNormalized =
      isZeroInteger(origLb) && isOneInteger(origStep);
  if (!isAlreadyNormalized) {
    if (Operation *definingOp = originalIv.getDefiningOp())
      usersToKeep.insert(definingOp);
  }
  rewriter.replaceAllUsesExcept(normalizedIv, originalIv, usersToKeep);
}

/// Replaces the uses of `normalizedIv` with `normalizedIv * origStep + origLb`.
/// Index-typed bounds are delegated to the affine-based helper; other integer
/// types are expanded with `arith.muli` / `arith.addi`, skipping the multiply
/// when the step is 1 and the add when the lower bound is 0.
void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                        Value normalizedIv, OpFoldResult origLb,
                                        OpFoldResult origStep) {
  if (getType(origLb).isIndex()) {
    denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
                                             origLb, origStep);
    return;
  }

  // The ops created here use `normalizedIv` (directly or transitively) and
  // must keep doing so after the replacement below.
  SmallPtrSet<Operation *, 2> newUsers;
  Value result = normalizedIv;

  if (!isOneInteger(origStep)) {
    Value stepValue = getValueOrCreateConstantIntOp(rewriter, loc, origStep);
    result = arith::MulIOp::create(rewriter, loc, normalizedIv, stepValue);
    newUsers.insert(result.getDefiningOp());
  }

  if (!isZeroInteger(origLb)) {
    Value lbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
    result = arith::AddIOp::create(rewriter, loc, result, lbValue);
    newUsers.insert(result.getDefiningOp());
  }

  rewriter.replaceAllUsesExcept(normalizedIv, result, newUsers);
}

/// Returns the product of all `values`, accumulated one factor at a time
/// through composed and folded affine applies, starting from the index
/// attribute 1.
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<OpFoldResult> values) {
  assert(!values.empty() && "unexecpted empty array");
  AffineExpr lhs, rhs;
  bindSymbols(rewriter.getContext(), lhs, rhs);
  AffineExpr productExpr = lhs * rhs;

  OpFoldResult accumulated = rewriter.getIndexAttr(1);
  for (OpFoldResult factor : values)
    accumulated = affine::makeComposedFoldedAffineApply(
        rewriter, loc, productExpr,
        ArrayRef<OpFoldResult>{accumulated, factor});
  return accumulated;
}

/// Helper function to multiply a sequence of values.
/// Index-typed values are multiplied via `getProductOfIndexes`. For other
/// integer types, factors that are the constant 1 are skipped and the rest
/// are chained with `arith.muli`; if every factor is 1, a constant 1 of the
/// values' type is materialized instead.
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                       ArrayRef<Value> values) {
  assert(!values.empty() && "unexpected empty list");
  Type elementType = getType(values.front());
  if (elementType.isIndex()) {
    OpFoldResult indexProduct =
        getProductOfIndexes(rewriter, loc, getAsOpFoldResult(values));
    return getValueOrCreateConstantIndexOp(rewriter, loc, indexProduct);
  }

  // Keep only the factors that are not statically known to be 1.
  SmallVector<Value> factors;
  for (Value candidate : values) {
    std::optional<int64_t> constant = getConstantIntValue(candidate);
    if (constant && *constant == 1)
      continue;
    factors.push_back(candidate);
  }

  if (factors.empty())
    return arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getOneAttr(elementType))
        .getResult();

  Value product = factors.front();
  for (Value factor : ArrayRef<Value>(factors).drop_front())
    product = arith::MulIOp::create(rewriter, loc, product, factor).getResult();
  return product;
}

/// Recover the per-loop induction variables from the induction variable of a
/// linearized (coalesced) loop.
///
/// For each original loop, the value of the
/// induction variable can be obtained by dividing the induction variable of
/// the linearized loop by the total number of iterations of the loops nested
/// in it modulo the number of iterations in this loop (remove the values
/// related to the outer loops):
///   iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
/// Compute these iteratively from the innermost loop by creating a "running
/// quotient" of division by the range.
///
/// Returns the delinearized values (one per entry of `ubs`, in the same order)
/// together with the set of newly created operations that use `linearizedIv`
/// or a quotient derived from it; the caller passes this set as the exception
/// list when replacing uses of the linearized induction variable.
static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                             Value linearizedIv, ArrayRef<Value> ubs) {

  // Index-typed induction variables are handled by a single
  // `affine.delinearize_index` op, which is also the only preserved user.
  if (linearizedIv.getType().isIndex()) {
    Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
        rewriter, loc, linearizedIv, ubs);
    auto resultVals = llvm::map_to_vector(
        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
  }

  SmallVector<Value> delinearizedIvs(ubs.size());
  SmallPtrSet<Operation *, 2> preservedUsers;

  // Record which upper bounds are the constant 1: such loops run a single
  // iteration, so no division/remainder needs to be emitted for them.
  llvm::BitVector isUbOne(ubs.size());
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    auto ubCst = getConstantIntValue(ub);
    if (ubCst && ubCst.value() == 1)
      isUbOne.set(index);
  }

  // Prune the lead ubs that are all ones.
  // Their induction variables are the constant 0.
  unsigned numLeadingOneUbs = 0;
  for (auto [index, ub] : llvm::enumerate(ubs)) {
    if (!isUbOne.test(index)) {
      break;
    }
    delinearizedIvs[index] = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(ub.getType()));
    numLeadingOneUbs++;
  }

  // Walk the remaining bounds from the last (innermost) to the first
  // non-pruned one; `previous` holds the running quotient.
  Value previous = linearizedIv;
  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
    // `idx` runs from `ubs.size() - 1` down to `numLeadingOneUbs`.
    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
    // Divide out the range of the loop processed in the previous step, unless
    // this is the first step or that range is 1.
    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
      previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
      preservedUsers.insert(previous.getDefiningOp());
    }
    Value iv = previous;
    // The last step (the outermost non-pruned loop) uses the running quotient
    // directly; every other step takes the remainder by its own range, or is
    // the constant 0 when that range is 1.
    if (i != e - 1) {
      if (!isUbOne.test(idx)) {
        iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
        preservedUsers.insert(iv.getDefiningOp());
      } else {
        iv = arith::ConstantOp::create(
            rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType()));
      }
    }
    delinearizedIvs[idx] = iv;
  }
  return {delinearizedIvs, preservedUsers};
}

/// Coalesces the loop nest `loops` (ordered outermost first) into the
/// outermost loop. Every loop is first normalized to iterate from 0 with step
/// 1, the outermost loop's upper bound becomes the product of all iteration
/// counts, and the original induction variables are recomputed inside the
/// innermost body from the linearized induction variable. The inner loops are
/// then inlined into their parents, innermost first.
///
/// Returns failure (without modifying anything) if fewer than two loops are
/// given, success otherwise. All IR mutations, including the update of the
/// outermost upper bound, are routed through `rewriter`.
LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
                                  MutableArrayRef<scf::ForOp> loops) {
  if (loops.size() < 2)
    return failure();

  scf::ForOp innermost = loops.back();
  scf::ForOp outermost = loops.front();

  // 1. Make sure all loops iterate from 0 to upperBound with step 1.  This
  // allows the following code to assume upperBound is the number of iterations.
  for (auto loop : loops) {
    OpBuilder::InsertionGuard g(rewriter);
    // The normalized bounds are computed above the whole nest.
    rewriter.setInsertionPoint(outermost);
    Value lb = loop.getLowerBound();
    Value ub = loop.getUpperBound();
    Value step = loop.getStep();
    auto newLoopRange =
        emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);

    rewriter.modifyOpInPlace(loop, [&]() {
      loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.offset));
      loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                       newLoopRange.size));
      loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
                                                 newLoopRange.stride));
    });
    // The original induction variable value is rebuilt at the start of the
    // innermost body from the (now normalized) induction variable.
    rewriter.setInsertionPointToStart(innermost.getBody());
    denormalizeInductionVariable(rewriter, loop.getLoc(),
                                 loop.getInductionVar(), lb, step);
  }

  // 2. Emit code computing the upper bound of the coalesced loop as product
  // of the number of iterations of all loops.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outermost);
  Location loc = outermost.getLoc();
  SmallVector<Value> upperBounds = llvm::map_to_vector(
      loops, [](auto loop) { return loop.getUpperBound(); });
  Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
  // Go through the rewriter so that listeners are notified of the in-place
  // change, consistently with the bound updates in step 1.
  rewriter.modifyOpInPlace(outermost,
                           [&]() { outermost.setUpperBound(upperBound); });

  // 3. Recover the per-loop induction variables from the linearized one. The
  // ops created by the delinearization must keep using the linearized value.
  rewriter.setInsertionPointToStart(innermost.getBody());
  auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
      rewriter, loc, outermost.getInductionVar(), upperBounds);
  rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
                                preservedUsers);

  // 4. Inline each inner loop into its parent, innermost first.
  for (int i = loops.size() - 1; i > 0; --i) {
    auto outerLoop = loops[i - 1];
    auto innerLoop = loops[i];

    Operation *innerTerminator = innerLoop.getBody()->getTerminator();
    auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
    assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
    for (Value &yieldedVal : yieldedVals) {
      // The yielded value may be an iteration argument of the inner loop
      // which is about to be inlined.
      auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
      if (iter != innerLoop.getRegionIterArgs().end()) {
        unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
        // `outerLoop` iter args identical to the `innerLoop` init args.
        assert(iterArgIndex < innerLoop.getInitArgs().size());
        yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
      }
    }
    rewriter.eraseOp(innerTerminator);

    // The inner block arguments are (induction variable, iter args...); they
    // are replaced by the delinearized induction variable and the iter args of
    // the parent loop.
    SmallVector<Value> innerBlockArgs;
    innerBlockArgs.push_back(delinearizeIvs[i]);
    llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
    rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
                               Block::iterator(innerLoop), innerBlockArgs);
    rewriter.replaceOp(innerLoop, yieldedVals);
  }
  return success();
}

/// Convenience overload of `coalesceLoops` that builds an `IRRewriter` from
/// the context of the first loop. Fails on an empty loop list.
LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
  if (loops.empty())
    return failure();

  MLIRContext *context = loops.front().getContext();
  IRRewriter rewriter(context);
  return coalesceLoops(rewriter, loops);
}

/// Collects the perfectly nested loops rooted at `op` and coalesces every band
/// of them whose bounds are defined above the band and whose iter_args form an
/// uninterrupted chain. Returns success if at least one band was coalesced.
LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
  SmallVector<scf::ForOp> nest;
  getPerfectlyNestedLoops(nest, op);
  const unsigned numLoops = nest.size();

  // Look for a band of loops that can be coalesced, i.e. perfectly nested
  // loops with bounds defined above some loop.

  // 1. For each loop, find the outermost enclosing loop above which all of its
  // bound operands are defined (the loop itself if there is none).
  SmallVector<unsigned> operandsDefinedAbove(numLoops);
  for (unsigned inner = 0; inner < numLoops; ++inner) {
    scf::ForOp loop = nest[inner];
    SmallVector<Value> boundsOperands = {loop.getLowerBound(),
                                         loop.getUpperBound(), loop.getStep()};
    unsigned definedAbove = inner;
    for (unsigned outer = 0; outer < inner; ++outer) {
      if (areValuesDefinedAbove(boundsOperands, nest[outer].getRegion())) {
        definedAbove = outer;
        break;
      }
    }
    operandsDefinedAbove[inner] = definedAbove;
  }

  // 2. For each inner loop, check that its inits are the iter_args of the
  // immediately enclosing loop and that the enclosing loop yields exactly the
  // inner loop's results. Record where each such chain starts; a loop that
  // breaks the chain starts a new one at itself.
  SmallVector<unsigned> iterArgChainStart(numLoops);
  iterArgChainStart[0] = 0;
  for (unsigned i = 1; i < numLoops; ++i) {
    scf::ForOp outerLoop = nest[i - 1];
    scf::ForOp innerLoop = nest[i];
    const bool continuesChain =
        outerLoop.getNumRegionIterArgs() == innerLoop.getNumRegionIterArgs() &&
        llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()) &&
        llvm::equal(outerLoop.getBody()->getTerminator()->getOperands(),
                    innerLoop.getResults());
    if (continuesChain)
      iterArgChainStart[i] = iterArgChainStart[i - 1];
    else
      iterArgChainStart[i] = i;
  }

  // 3. Identify bands of loops such that the operands of all of them are
  // defined above the first loop in the band.  Traverse the nest bottom-up
  // so that modifications don't invalidate the inner loops.
  LogicalResult result = failure();
  for (unsigned bandEnd = numLoops; bandEnd > 0; --bandEnd) {
    unsigned bandStart = 0;
    for (; bandStart < bandEnd - 1; ++bandStart) {
      auto firstPos = std::next(operandsDefinedAbove.begin(), bandStart);
      auto lastPos = std::next(operandsDefinedAbove.begin(), bandEnd);
      unsigned maxPos = *std::max_element(firstPos, lastPos);
      if (maxPos > bandStart || iterArgChainStart[bandEnd - 1] > bandStart)
        continue;
      auto band =
          llvm::MutableArrayRef(nest.data() + bandStart, bandEnd - bandStart);
      if (succeeded(coalesceLoops(band)))
        result = success();
      break;
    }
    // If a band was found and transformed, keep looking at the loops above
    // the outermost transformed loop.
    if (bandStart != bandEnd - 1)
      bandEnd = bandStart + 1;
  }
  return result;
}

/// Collapses the dimensions of the `scf.parallel` op `loops` according to
/// `combinedDimensions`: each entry lists the original dimensions that are
/// merged into one dimension of the new loop. A new `scf.parallel` with lower
/// bound 0 and step 1 in every dimension is created, the original induction
/// variables are recomputed from the new ones with divisions and remainders,
/// the original body is moved into the new loop, and `loops` is erased.
void mlir::collapseParallelLoops(
    RewriterBase &rewriter, scf::ParallelOp loops,
    ArrayRef<std::vector<unsigned>> combinedDimensions) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loops);
  Location loc = loops.getLoc();

  // Presort combined dimensions.
  // NOTE(review): the sorted copy is only used below to compute the combined
  // upper bounds; the delinearization in the body builder indexes the
  // caller-provided `combinedDimensions` in its original order.
  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
  for (auto &dims : sortedDimensions)
    llvm::sort(dims);

  // Normalize ParallelOp's iteration pattern.
  // For every original dimension, compute the normalized upper bound (the
  // iteration count) above the loop, and rewrite the uses of the original
  // induction variable in terms of a zero-based, unit-step one.
  SmallVector<Value, 3> normalizedUpperBounds;
  for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
    OpBuilder::InsertionGuard g2(rewriter);
    rewriter.setInsertionPoint(loops);
    Value lb = loops.getLowerBound()[i];
    Value ub = loops.getUpperBound()[i];
    Value step = loops.getStep()[i];
    auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
        rewriter, loops.getLoc(), newLoopRange.size));

    rewriter.setInsertionPointToStart(loops.getBody());
    denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
                                 step);
  }

  // Combine iteration spaces.
  // Each new dimension runs from 0 with step 1 up to the product of the
  // normalized upper bounds of the dimensions it combines.
  SmallVector<Value, 3> lowerBounds, upperBounds, steps;
  auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
  for (auto &sortedDimension : sortedDimensions) {
    Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1);
    for (auto idx : sortedDimension) {
      newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
                                            normalizedUpperBounds[idx]);
    }
    lowerBounds.push_back(cst0);
    steps.push_back(cst1);
    upperBounds.push_back(newUpperBound);
  }

  // Create new ParallelLoop with conversions to the original induction values.
  // The loop below uses divisions to get the relevant range of values in the
  // new induction value that represent each range of the original induction
  // value. The remainders then determine based on that range, which iteration
  // of the original induction value this represents. This is a normalized value
  // that is un-normalized already by the previous logic.
  auto newPloop = scf::ParallelOp::create(
      rewriter, loc, lowerBounds, upperBounds, steps,
      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
          Value previous = ploopIVs[i];
          // NOTE(review): an empty group would make the `- 1` below wrap
          // around; this presumably never happens in practice - confirm
          // against callers.
          unsigned numberCombinedDimensions = combinedDimensions[i].size();
          // Iterate over all except the last induction value.
          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
            unsigned idx = combinedDimensions[i][j];

            // Determine the current induction value's current loop iteration
            Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
                                              normalizedUpperBounds[idx]);
            replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
                                       loops.getRegion());

            // Remove the effect of the current induction value to prepare for
            // the next value.
            previous = arith::DivSIOp::create(insideBuilder, loc, previous,
                                              normalizedUpperBounds[idx]);
          }

          // The final induction value is just the remaining value.
          unsigned idx = combinedDimensions[i][0];
          replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx),
                                     previous, loops.getRegion());
        }
      });

  // Replace the old loop with the new loop.
  // The old terminator is dropped, the remaining old body operations are
  // spliced in front of the new loop's terminator, and the old op is erased.
  loops.getBody()->back().erase();
  newPloop.getBody()->getOperations().splice(
      Block::iterator(newPloop.getBody()->back()),
      loops.getBody()->getOperations());
  loops.erase();
}

// Move the ops of `outer`'s body that precede `inner` to just before `outer`.
// Such ops include the ops that have been introduced by parametric tiling.
// An op stays in place, and makes the result a failure, when it depends on
// the induction variable of `outer` (i.e. it belongs to the forward slice
// rooted there, as for triangular loops), when it has regions, or when it is
// not memory-effect free. Intermediate scf.for ops also stay in place but are
// not counted as a failure.
// Return failure when any op fails to hoist.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
  Operation *innerOp = inner.getOperation();

  // Everything reachable from the induction variable of `outer`, not looking
  // through `inner` itself.
  SetVector<Operation *> ivDependents;
  ForwardSliceOptions sliceOptions;
  sliceOptions.filter = [innerOp](Operation *candidate) {
    return candidate != innerOp;
  };
  getForwardSlice(outer.getInductionVar(), &ivDependents, sliceOptions);

  bool allHoisted = true;
  SmallVector<Operation *, 8> hoistable;
  for (Operation &candidate : outer.getBody()->without_terminator()) {
    // Stop when encountering the inner loop.
    if (&candidate == innerOp)
      break;
    // Ops depending on the induction variable cannot leave the loop.
    if (ivDependents.count(&candidate) > 0) {
      allHoisted = false;
      continue;
    }
    // Skip intermediate scf::ForOp, these are not considered a failure.
    if (isa<scf::ForOp>(candidate))
      continue;
    // Other ops with regions and ops with side effects are left in place.
    // TODO: loads to immutable memory regions are ok.
    if (candidate.getNumRegions() > 0 || !isMemoryEffectFree(&candidate)) {
      allHoisted = false;
      continue;
    }
    hoistable.push_back(&candidate);
  }

  // Move only after the scan so the block is not modified while iterating.
  Operation *outerOp = outer.getOperation();
  for (Operation *toMove : hoistable)
    toMove->moveBefore(outerOp);
  return success(allHoisted);
}

// Try to isolate perfectly nested bands among the interTile and the intraTile
// loops by hoisting the ops that sit between the first loop of a band and
// each of the following loops. The intraTile band is processed first.
// Hoisting stops at the first failure, which is then returned; bands with at
// most one loop trivially succeed.
static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
  const Loops &interTile = tileLoops.first;
  const Loops &intraTile = tileLoops.second;
  auto size = interTile.size();
  assert(size == intraTile.size());
  if (size <= 1)
    return success();

  LogicalResult status = success();
  auto isolateBand = [&](const Loops &band) {
    for (unsigned s = 1; s < size && succeeded(status); ++s)
      status = hoistOpsBetween(band[0], band[s]);
  };
  isolateBand(intraTile);
  isolateBand(interTile);
  return status;
}

/// Append to `forOps` the chain of perfectly nested loops that starts at
/// `rootForOp`, collecting at most `maxLoops` loops. A loop continues the
/// chain when it is the first and only non-terminator operation in the body
/// of its parent loop and is itself of type `T`.
template <typename T>
static void getPerfectlyNestedLoopsImpl(
    SmallVectorImpl<T> &forOps, T rootForOp,
    unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
  T current = rootForOp;
  for (unsigned depth = 0; depth != maxLoops; ++depth) {
    forOps.push_back(current);

    // The body must hold exactly one operation in front of its terminator.
    Block &body = current.getRegion().front();
    const bool hasSingleNestedOp = body.begin() == std::prev(body.end(), 2);
    if (!hasSingleNestedOp)
      return;

    // ... and that operation must be another loop of the same kind.
    current = dyn_cast<T>(&body.front());
    if (!current)
      return;
  }
}

// Stripmines `forOp` by `factor` and sinks the resulting point loop under
// each of `targets`: the step of `forOp` is multiplied by `factor`, and a new
// loop running from the induction variable of `forOp` to
// min(upper bound, iv + new step) with the original step is created inside
// every target, taking over the target's former body.
// Returns the new loops, one per target and in the same order.
static Loops stripmineSink(scf::ForOp forOp, Value factor,
                           ArrayRef<scf::ForOp> targets) {
  assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
  Value originalStep = forOp.getStep();
  Value iv = forOp.getInductionVar();

  // Scale the step of the loop being stripmined; the product is computed
  // right before `forOp`.
  OpBuilder outerBuilder(forOp);
  forOp.setStep(arith::MulIOp::create(outerBuilder, forOp.getLoc(),
                                      originalStep, factor));

  Loops sunkLoops;
  for (scf::ForOp target : targets) {
    assert(!target.getUnsignedCmp() && "unsigned loops are not supported");
    Block *targetBody = target.getBody();

    // Remember the extent of the current body before adding anything to it,
    // so that exactly those ops (minus the terminator) get moved below.
    auto firstOp = targetBody->begin();
    auto numOps = targetBody->getOperations().size();

    // The bounds of the new loop are computed before the target's terminator.
    OpBuilder innerBuilder = OpBuilder::atBlockTerminator(targetBody);
    Location targetLoc = target.getLoc();
    Value nextIv =
        arith::AddIOp::create(innerBuilder, targetLoc, iv, forOp.getStep());
    Value clampedUb = arith::MinSIOp::create(innerBuilder, targetLoc,
                                             forOp.getUpperBound(), nextIv);

    // Move [firstOp, firstOp + numOps - 1) into the new loop and make the
    // moved ops use the new induction variable.
    auto pointLoop = scf::ForOp::create(innerBuilder, targetLoc, iv, clampedUb,
                                        originalStep);
    auto &pointLoopOps = pointLoop.getBody()->getOperations();
    pointLoopOps.splice(pointLoopOps.begin(), targetBody->getOperations(),
                        firstOp, std::next(firstOp, numOps - 1));
    replaceAllUsesInRegionWith(iv, pointLoop.getInductionVar(),
                               pointLoop.getRegion());

    sunkLoops.push_back(pointLoop);
  }

  return sunkLoops;
}

// Single-target convenience wrapper around the multi-target `stripmineSink`:
// stripmines `forOp` by `factor` and sinks it under `target`.
// Returns the new for operation, nested immediately under `target`.
template <typename SizeType>
static scf::ForOp stripmineSink(scf::ForOp forOp, SizeType factor,
                                scf::ForOp target) {
  // TODO: Use cheap structural assertions that targets are nested under
  // forOp and that targets are not nested under each other when DominanceInfo
  // exposes the capability. It seems overkill to construct a whole function
  // dominance tree at this point.
  Loops sunk = stripmineSink(forOp, factor, ArrayRef<scf::ForOp>(target));
  assert(sunk.size() == 1 && "Expected 1 inner forOp");
  return sunk.front();
}

/// Tiles `forOps` by the corresponding `sizes`, sinking the point loops under
/// `targets`. Each loop is stripmined in turn and sunk under the loops
/// produced by the previous one (initially `targets`). Returns, for every
/// processed loop, the loops created for it.
SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
                                 ArrayRef<Value> sizes,
                                 ArrayRef<scf::ForOp> targets) {
  SmallVector<Loops, 8> allSunkLoops;
  Loops sinkTargets(targets);
  for (auto [forOp, size] : llvm::zip(forOps, sizes)) {
    Loops sunk = stripmineSink(forOp, size, sinkTargets);
    allSunkLoops.push_back(sunk);
    // The next loop is sunk below the loops that were just created.
    sinkTargets = sunk;
  }
  return allSunkLoops;
}

/// Single-target variant of `tile`: tiles `forOps` by `sizes`, sinking under
/// `target`, and returns the one loop created for each tiled loop.
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                 scf::ForOp target) {
  SmallVector<Loops, 8> perLoop =
      tile(forOps, sizes, ArrayRef<scf::ForOp>(target));
  Loops flattened;
  // With a single target, every tiled loop yields exactly one new loop.
  for (const Loops &sunk : perLoop)
    flattened.push_back(llvm::getSingleElement(sunk));
  return flattened;
}

/// Tiles the perfectly nested loops rooted at `rootForOp` by `sizes` and
/// sinks the point loops under the innermost collected loop. At most
/// `sizes.size()` loops are collected; surplus sizes are ignored. Returns the
/// new loops, or an empty list when `sizes` is empty.
Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
  // Collect perfectly nested loops.  If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  // No loop is collected when no size is given: there is nothing to tile and
  // `forOps.back()` below would be invalid.
  if (forOps.empty())
    return {};
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  return ::tile(forOps, sizes, forOps.back());
}

/// Appends to `nestedLoops` the perfectly nested `scf.for` loops starting at
/// `root` (included), without any limit on the nesting depth.
void mlir::getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                                   scf::ForOp root) {
  getPerfectlyNestedLoopsImpl<scf::ForOp>(nestedLoops, root);
}

/// Tiles the perfectly nested loops rooted at `rootForOp` so that the i-th
/// outer loop executes (at most) `sizes[i]` iterations. At most
/// `sizes.size()` loops are collected; surplus sizes are ignored. Returns the
/// pair (collected outer loops, newly created inner loops), or a pair of
/// empty lists when `sizes` is empty.
TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
                                       ArrayRef<int64_t> sizes) {
  // Collect perfectly nested loops.  If more size values provided than nested
  // loops available, truncate `sizes`.
  SmallVector<scf::ForOp, 4> forOps;
  forOps.reserve(sizes.size());
  getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
  // No loop is collected when no size is given: there is nothing to tile and
  // `forOps.back()` below would be invalid.
  if (forOps.empty())
    return TileLoops();
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());

  // Compute the tile sizes such that i-th outer loop executes size[i]
  // iterations.  Given that the loop current executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)
  // iterations, we need to tile with size ceildiv(numIterations, size[i]).
  SmallVector<Value, 4> tileSizes;
  tileSizes.reserve(sizes.size());
  for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
    assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");

    auto forOp = forOps[i];
    // The size computation is emitted right before the loop it applies to.
    OpBuilder builder(forOp);
    auto loc = forOp.getLoc();
    Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
                                       forOp.getLowerBound());
    Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep());
    Value iterationsPerBlock =
        ceilDivPositive(builder, loc, numIterations, sizes[i]);
    tileSizes.push_back(iterationsPerBlock);
  }

  // Call parametric tiling with the given sizes.
  auto intraTile = tile(forOps, tileSizes, forOps.back());
  TileLoops tileLoops = std::make_pair(forOps, intraTile);

  // TODO: for now we just ignore the result of band isolation.
  // In the future, mapping decisions may be impacted by the ability to
  // isolate perfectly nested bands.
  (void)tryIsolateBands(tileLoops);

  return tileLoops;
}

/// Fuses the two `scf.forall` loops `target` and `source` into a new loop
/// created right after `source`. The new loop takes its bounds, steps and
/// mapping from `source` and its shared outputs from `target` followed by
/// `source`. The bodies and the yielding ops of the terminators are cloned
/// (target first), and both original loops are replaced by the matching
/// results of the fused loop, which is returned.
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  const unsigned targetResultCount = target.getNumResults();
  const unsigned sourceResultCount = source.getNumResults();

  // Shared outputs of the fused loop: target's first, then source's.
  SmallVector<Value> sharedOuts;
  llvm::append_range(sharedOuts, target.getOutputs());
  llvm::append_range(sharedOuts, source.getOutputs());

  // The fused loop is created after the source loop.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fused = scf::ForallOp::create(
      rewriter, source.getLoc(), source.getMixedLowerBound(),
      source.getMixedUpperBound(), source.getMixedStep(), sharedOuts,
      source.getMapping());

  // Both sets of induction variables map onto those of the fused loop; the
  // shared outputs map onto the front (target) and back (source) of the fused
  // loop's region iter args.
  IRMapping valueMap;
  valueMap.map(target.getInductionVars(), fused.getInductionVars());
  valueMap.map(source.getInductionVars(), fused.getInductionVars());
  valueMap.map(target.getRegionIterArgs(),
               fused.getRegionIterArgs().take_front(targetResultCount));
  valueMap.map(source.getRegionIterArgs(),
               fused.getRegionIterArgs().take_back(sourceResultCount));

  // Clones every op of a range at the current insertion point.
  auto cloneAll = [&](auto &&ops) {
    for (Operation &op : ops)
      rewriter.clone(op, valueMap);
  };

  // Everything except the terminators goes into the fused body.
  rewriter.setInsertionPointToStart(fused.getBody());
  cloneAll(target.getBody()->without_terminator());
  cloneAll(source.getBody()->without_terminator());

  // The yielding ops of both in_parallel terminators go into the new one.
  scf::InParallelOp targetTerm = target.getTerminator();
  scf::InParallelOp sourceTerm = source.getTerminator();
  scf::InParallelOp fusedTerm = fused.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerm.getBody());
  cloneAll(targetTerm.getYieldingOps());
  cloneAll(sourceTerm.getYieldingOps());

  // Uses of the old loops are redirected to the fused loop's results.
  rewriter.replaceOp(target, fused.getResults().take_front(targetResultCount));
  rewriter.replaceOp(source, fused.getResults().take_back(sourceResultCount));

  return fused;
}

/// Fuses the two `scf.for` loops `target` and `source` into a new loop
/// created right after `source`. The new loop takes its bounds, step and
/// comparison signedness from `source` and its init_args from `target`
/// followed by `source`. The bodies are cloned (target first), a combined
/// yield is created when there is anything to yield, and both original loops
/// are replaced by the matching results of the fused loop, which is returned.
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
                                                scf::ForOp source,
                                                RewriterBase &rewriter) {
  assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
         "incompatible signedness");
  const unsigned targetResultCount = target.getNumResults();
  const unsigned sourceResultCount = source.getNumResults();

  // init_args of the fused loop: target's first, then source's.
  SmallVector<Value> initArgs;
  llvm::append_range(initArgs, target.getInitArgs());
  llvm::append_range(initArgs, source.getInitArgs());

  // Create a new scf.for op after the source loop (with scf.yield terminator
  // (without arguments) only in case its init_args is empty).
  rewriter.setInsertionPointAfter(source);
  scf::ForOp fused = scf::ForOp::create(
      rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
      source.getStep(), initArgs, /*bodyBuilder=*/nullptr,
      source.getUnsignedCmp());

  // Original induction variables and iter args map onto those of the fused
  // loop: target's iter args onto the front, source's onto the back.
  IRMapping valueMap;
  valueMap.map(target.getInductionVar(), fused.getInductionVar());
  valueMap.map(target.getRegionIterArgs(),
               fused.getRegionIterArgs().take_front(targetResultCount));
  valueMap.map(source.getInductionVar(), fused.getInductionVar());
  valueMap.map(source.getRegionIterArgs(),
               fused.getRegionIterArgs().take_back(sourceResultCount));

  // Clone target's body into the fused loop, then source's body, and collect
  // the remapped yield operands in the same order.
  rewriter.setInsertionPointToStart(fused.getBody());
  auto cloneBody = [&](scf::ForOp loop) {
    for (Operation &op : loop.getBody()->without_terminator())
      rewriter.clone(op, valueMap);
  };
  cloneBody(target);
  cloneBody(source);

  SmallVector<Value> yieldedValues;
  auto appendMappedYields = [&](scf::ForOp loop) {
    for (Value operand : loop.getBody()->getTerminator()->getOperands())
      yieldedValues.push_back(valueMap.lookupOrDefault(operand));
  };
  appendMappedYields(target);
  appendMappedYields(source);
  if (!yieldedValues.empty())
    scf::YieldOp::create(rewriter, source.getLoc(), yieldedValues);

  // Uses of the old loops are redirected to the fused loop's results.
  rewriter.replaceOp(target, fused.getResults().take_front(targetResultCount));
  rewriter.replaceOp(source, fused.getResults().take_back(sourceResultCount));

  return fused;
}

/// Rewrites `forallOp` into an equivalent `scf.forall` whose lower bounds are
/// 0 and whose steps are 1. The body is moved into the new op and the uses of
/// the induction variables are rewritten in terms of the original lower
/// bounds and steps. An already normalized op is returned unchanged.
FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
                                                 scf::ForallOp forallOp) {
  if (forallOp.isNormalized())
    return forallOp;

  SmallVector<OpFoldResult> origLbs = forallOp.getMixedLowerBound();
  SmallVector<OpFoldResult> origUbs = forallOp.getMixedUpperBound();
  SmallVector<OpFoldResult> origSteps = forallOp.getMixedStep();

  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = forallOp.getLoc();
  rewriter.setInsertionPoint(forallOp);

  // The normalized upper bound of each dimension is its iteration count.
  SmallVector<OpFoldResult> normalizedUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(origLbs, origUbs, origSteps)) {
    Range normalized = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
    normalizedUbs.push_back(normalized.size);
  }
  (void)foldDynamicIndexList(normalizedUbs);

  // Use the normalized builder since the lower bounds are always 0 and the
  // steps are always 1.
  auto newForallOp = scf::ForallOp::create(
      rewriter, loc, normalizedUbs, forallOp.getOutputs(),
      forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});

  // Move the original body in front of the block created by the builder, then
  // remove the original empty block in the new loop.
  Region &newBodyRegion = newForallOp.getBodyRegion();
  rewriter.inlineRegionBefore(forallOp.getBodyRegion(), newBodyRegion,
                              newBodyRegion.begin());
  rewriter.eraseBlock(&newBodyRegion.back());

  // Update the users of the original loop variables.
  rewriter.setInsertionPointToStart(newForallOp.getBody());
  for (auto [idx, iv] : llvm::enumerate(newForallOp.getInductionVars())) {
    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, origLbs[idx]);
    Value stepValue =
        getValueOrCreateConstantIndexOp(rewriter, loc, origSteps[idx]);
    denormalizeInductionVariable(rewriter, loc, iv, lbValue, stepValue);
  }

  rewriter.replaceOp(forallOp, newForallOp);
  return newForallOp;
}

/// Returns true if `loops` (ordered outermost first) are all `scf.for` ops
/// and, for every pair of consecutive loops, the inner loop is initialized
/// with exactly the iter args of the outer loop, the outer loop yields exactly
/// the results of the inner loop, and each of those values has a single use.
/// A single loop only needs to be an `scf.for`.
bool mlir::isPerfectlyNestedForLoops(
    MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "unexpected empty loop nest");
  if (loops.size() == 1)
    return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());

  for (unsigned depth = 0, last = loops.size() - 1; depth < last; ++depth) {
    auto outerFor = dyn_cast_or_null<scf::ForOp>(loops[depth].getOperation());
    auto innerFor =
        dyn_cast_or_null<scf::ForOp>(loops[depth + 1].getOperation());
    if (!(outerFor && innerFor))
      return false;

    // The inner loop must be seeded with the outer iter args, position by
    // position, and must be their only user.
    auto outerIterArgs = outerFor.getRegionIterArgs();
    auto innerInits = innerFor.getInitArgs();
    if (outerIterArgs.size() != innerInits.size())
      return false;
    for (auto [iterArg, init] : llvm::zip_equal(outerIterArgs, innerInits)) {
      if (init != iterArg || !llvm::hasSingleElement(iterArg.getUses()))
        return false;
    }

    // The outer loop must yield the inner results, position by position, and
    // the yield must be their only user.
    ValueRange outerYields =
        cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
    ValueRange innerResults = innerFor.getResults();
    if (outerYields.size() != innerResults.size())
      return false;
    for (auto [yielded, result] : llvm::zip_equal(outerYields, innerResults)) {
      if (yielded != result || !llvm::hasSingleElement(result.getUses()))
        return false;
    }
  }
  return true;
}
